﻿#pragma kernel Clear
#pragma kernel InsertInGrid
#pragma kernel SortByGrid
#pragma kernel ComputeDensity
#pragma kernel ApplyDensity

#include "InterlockedUtils.cginc"
#include "MathUtils.cginc"
#include "GridUtils.cginc"
#include "Simplex.cginc"
#include "Bounds.cginc"
#include "SolverParameters.cginc"
#include "FluidKernels.cginc"

// For every slot of the sorted arrays, the index of the same particle in the input arrays (written by SortByGrid).
RWStructuredBuffer<int> sortedToOriginal;

// Per-particle slot inside its grid cell: the value cellCounts held just before InsertInGrid incremented it.
RWStructuredBuffer<uint> offsetInCell;
RWStructuredBuffer<uint> cellStart;    // start of each cell in the sorted item array. Clear resets it to INVALID; no kernel below fills it, so presumably it is computed elsewhere from cellCounts (confirm).
RWStructuredBuffer<uint> cellCounts;     // number of items in each cell.
StructuredBuffer<aabb> solverBounds;   // not referenced by any of the kernels below.

// Particle state in its original order. ApplyDensity overwrites positions (storing its occlusion term in w) and adds to velocities.
RWStructuredBuffer<float4> inputPositions; 
RWStructuredBuffer<float4> inputVelocities; 
RWStructuredBuffer<float4> inputColors; // not referenced by any of the kernels below.
// Copies of positions/velocities grouped by grid cell (written by SortByGrid, read by the neighbor loops).
RWStructuredBuffer<float4> sortedPositions; 
RWStructuredBuffer<float4> sortedVelocities; 
// Per-particle output of ComputeDensity, indexed like the input arrays:
// x = summed Poly6 kernel, y = lambda, z = summed scaled Spiky kernel, w = summed squared scaled Spiky kernel.
RWStructuredBuffer<float4> fluidData; 

// Element [3] is read by every per-particle kernel as the exclusive upper bound on the particle index.
StructuredBuffer<uint> dispatch;

// each emitter has its own global radius, not possible to have foam emitters interact with each other.
float particleRadius;          // rest volume is computed as (2 * particleRadius)^(3 - mode).
float smoothingRadius;         // grid cell size and cutoff distance for neighbor contributions.
float surfaceTension;          // scales the Cohesion-based correction in ApplyDensity.
float pressure;                // scales the accumulated position correction in ApplyDensity.
float viscosity;               // scales the XSPH velocity smoothing in ApplyDensity.
float4 volumeLightDirection;   // xyz is dotted with the direction to each neighbor to build the AO term.

float deltaTime;               // divides the position correction to turn it into a velocity change.

[numthreads(128, 1, 1)]
void Clear (uint3 id : SV_DispatchThreadID)
{
    // Each thread handles the grid cell matching its dispatch index;
    // threads whose index is maxCells or larger do nothing.
    const uint cell = id.x;
    if (cell < maxCells)
    {
        // Mark the cell as empty: zero items and no valid start offset.
        cellCounts[cell] = 0;
        cellStart[cell] = INVALID;
    }
}

[numthreads(128, 1, 1)]
void InsertInGrid (uint3 id : SV_DispatchThreadID)
{
    const uint particle = id.x;

    // dispatch[3] is the exclusive upper bound on particle indices to process.
    if (particle < dispatch[3])
    {
        // Grid cells are smoothingRadius wide: scale the position into cell units,
        // floor it to get cell coordinates and hash those into a cell index.
        float4 gridSpacePosition = inputPositions[particle] / smoothingRadius;
        uint cellIndex = GridHash(floor(gridSpacePosition).xyz);

        // Atomically count the particle into its cell. The count before the increment is
        // stored in offsetInCell, where SortByGrid later reads it as the particle's slot.
        InterlockedAdd(cellCounts[cellIndex], 1, offsetInCell[particle]);
    }
}

[numthreads(128, 1, 1)]
void SortByGrid (uint3 id : SV_DispatchThreadID)
{
    const uint particle = id.x;
    if (particle >= dispatch[3]) return;

    // Recompute the same cell index InsertInGrid used for this particle.
    float4 position = inputPositions[particle];
    uint cellIndex = GridHash(floor(position / smoothingRadius).xyz);

    // Destination slot: where the cell begins in the sorted arrays plus this particle's slot in the cell.
    uint slot = cellStart[cellIndex] + offsetInCell[particle];

    // Copy the particle into the cell-ordered arrays and remember where it came from.
    sortedToOriginal[slot] = particle;
    sortedPositions[slot] = position;
    sortedVelocities[slot] = inputVelocities[particle];
}

[numthreads(128, 1, 1)]
void ComputeDensity (uint3 id : SV_DispatchThreadID)
{
    const uint particle = id.x;
    if (particle >= dispatch[3]) return;

    float4 ownPosition = inputPositions[particle];
    int3 ownCell = floor(ownPosition / smoothingRadius).xyz;

    // Particle diameter raised to (3 - mode). mode == 1 drops the z axis from distances below,
    // in which case the exponent is 2.
    float restVolume = pow(abs(particleRadius * 2), 3 - mode);

    // Accumulator layout: x = Poly6 sum, y = lambda (filled in at the end),
    // z = scaled Spiky sum, w = sum of squared scaled Spiky terms.
    // It is seeded with the terms evaluated at distance zero (self-contribution).
    float selfGradient = restVolume * Spiky(0, smoothingRadius);
    float4 accum = float4(Poly6(0, smoothingRadius), 0, selfGradient, selfGradient * selfGradient);

    // Walk the 27 cell offsets in cellNeighborhood and gather contributions from the particles stored there.
    // NOTE(review): if cellNeighborhood contains the zero offset, this particle's own entry in
    // sortedPositions is visited as well (distance 0), on top of the seed above. Confirm that
    // counting the self term twice is intended.
    for (int k = 0; k < 27; ++k)
    {
        int3 candidateCell = ownCell + cellNeighborhood[k].xyz;
        uint bucket = GridHash(candidateCell);
        uint first = cellStart[bucket];
        uint count = cellCounts[bucket];

        for (uint n = 0; n < count; ++n)
        {
            float4 otherPosition = sortedPositions[first + n];

            float3 offset = (ownPosition - otherPosition).xyz;
            if (mode == 1)
                offset.z = 0;

            float dist = length(offset);

            // Skip particles beyond the smoothing radius, and particles whose actual cell
            // coordinates differ from the cell being visited (they only share its hash bucket).
            bool tooFar = dist > smoothingRadius;
            bool otherCell = any(candidateCell - floor(otherPosition / smoothingRadius).xyz);
            if (tooFar || otherCell) continue;

            float gradient = restVolume * Spiky(dist, smoothingRadius);
            accum += float4(Poly6(dist, smoothingRadius), 0, gradient, gradient * gradient);
        }
    }

    // Add the square of the summed gradient term to the squared-gradient total.
    accum.w += accum.z * accum.z;

    // Density would normally be weighted by mass (divided by invMass) and the state equation
    // density / restDensity - 1, with restDensity = mass / volume, would multiply by invMass again.
    // The two invMass factors cancel, leaving density * restVolume - 1. Only positive values are kept.
    float constraint = max(0, accum.x * restVolume - 1);

    // Lambda: negated constraint over the squared-gradient total (EPSILON keeps the divisor non-zero).
    accum.y = -constraint / (accum.w + EPSILON);

    fluidData[particle] = accum;
}

// Applies the results of ComputeDensity to one particle: accumulates a position correction from the
// lambda values of the particle and its neighbors (plus a Cohesion-based surface tension term),
// an XSPH viscosity velocity change, and a light-direction-weighted neighbor term ("AO") that is
// stored in the w component of the particle's position.
[numthreads(128, 1, 1)]
void ApplyDensity (uint3 id : SV_DispatchThreadID) 
{
    unsigned int i = id.x;
    if (i >= dispatch[3]) return;

    // Read this particle's state once; the final writes reuse these values.
    float4 positionA = inputPositions[i];
    float4 velocityA = inputVelocities[i];
    float4 fluidDataA = fluidData[i];

    int3 cellCoords = floor(positionA / smoothingRadius).xyz;

    // Particle diameter / smoothing diameter raised to (3 - mode); mode == 1 ignores the z axis below.
    float restVolume = pow(abs(particleRadius * 2),3-mode);
    float neighborhoodVolume = pow(abs(smoothingRadius * 2),3-mode);

    float4 pressureDelta = FLOAT4_ZERO;
    float4 viscAccel = FLOAT4_ZERO;
    float AO = 0;

    for (int k = 0; k < 27; ++k)
    {
        int3 neighborCoords = cellCoords + cellNeighborhood[k].xyz;
        uint cellIndex = GridHash(neighborCoords);
        uint start = cellStart[cellIndex];

        // The cell count does not change during this kernel, so read it once per cell
        // instead of on every loop iteration.
        uint count = cellCounts[cellIndex];

        for (uint j = 0; j < count; ++j)
        {
            uint sortedIndex = start + j;
            float4 positionB = sortedPositions[sortedIndex];

            float4 r = float4((positionA - positionB).xyz,0);

            if (mode == 1)
                r[2] = 0;

            float dist = length(r);

            // Reject particles beyond the smoothing radius, and particles whose actual cell
            // coordinates differ from the cell being visited (they only share its hash bucket).
            if (dist > smoothingRadius || any(neighborCoords - floor(positionB / smoothingRadius).xyz)) continue;

            // Only accepted neighbors need their velocity and fluid data, so these fetches
            // (the second one is indirect, through sortedToOriginal) happen after the rejection test.
            float4 velocityB = sortedVelocities[sortedIndex];
            float4 fluidDataB = fluidData[sortedToOriginal[sortedIndex]];

            float kern = Poly6(dist,smoothingRadius);
            float cAvg = Cohesion(dist,smoothingRadius);

            // XSPH viscosity:
            float4 relVel = float4((velocityB - velocityA).xyz,0);
            viscAccel += viscosity * relVel * kern * restVolume;

            // Surface tension correction added to both particles' lambdas.
            float st = 0.2 * cAvg * surfaceTension;
            float scorrA = - st / (fluidDataA[3] + EPSILON);
            float scorrB = - st / (fluidDataB[3] + EPSILON);

            // Direction from the neighbor towards this particle; EPSILON avoids dividing by zero
            // at distance 0. Shared by the pressure and AO terms.
            float4 direction = r / (dist + EPSILON);

            pressureDelta += direction * Spiky(dist,smoothingRadius) * ((fluidDataA[1] + scorrA) + (fluidDataB[1] + scorrB)) * restVolume;

            AO += max(0, dot (volumeLightDirection.xyz, direction.xyz) ) / (1 + dist) * restVolume;
        }
    }

    float4 delta = pressure * pressureDelta;

    // Turn the position correction into a velocity change. A zero time step would produce
    // Inf/NaN velocities, so in that case only the viscosity term is applied.
    float3 velocityChange = viscAccel.xyz;
    if (deltaTime != 0)
        velocityChange += delta.xyz / deltaTime;

    // modify position and velocity (position.w receives the normalized AO term):
    inputPositions[i] = float4((positionA + delta).xyz, 2*PI * AO / neighborhoodVolume);
    inputVelocities[i] = velocityA + float4(velocityChange,0);
}